/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow.dtensor;

// Defines a sharding spec over a tensor.
// A sharding spec is either a mesh dimension of the associated mesh or one of
// the special values scalar or not_sharded.
message ShardingSpec {
  // Field 1 was deleted; its number stays reserved so that it can never be
  // reused with a different meaning (backward compatibility).
  reserved 1;
  // If set, the mesh dimension this tensor dimension is sharded over.
  // NOTE: this is a proto3 string without explicit presence, so an unset
  // field and an empty string are indistinguishable on the wire.
  string sharding_spec = 2;
}

// Describes one dimension of a mesh: its name and its length.
// MeshProto holds a repeated list of these in `mesh_dimensions`.
message MeshDimensionProto {
  // Name of mesh dimension, required to create Mesh.
  string name = 1;
  // Length of mesh dimension.
  // NOTE(review): presumably the number of devices along this dimension and
  // therefore positive -- nothing in this file enforces that; confirm against
  // the code that builds a Mesh.
  int64 size = 2;
}

// Proto representation of a Layout: a list of sharding specs, the mesh
// configuration, and the layout type.
message LayoutProto {
  // The kind of layout being described.
  enum LayoutType {
    // Zero value; also what `type` reads as when it has not been set.
    UNKNOWN = 0;

    // A tiled layout for static and evenly shaped local shards.
    STATIC = 1;

    // A parted layout contains axes that are treated as independent by most of
    // SPMD expanders.
    // FIXME(b/285905569): The exact semantics is still being investigated.
    PARTED = 2;

    // A single device layout that represents DTensor on a single device Mesh.
    SINGLE_DEVICE = 3;
  }

  // Sharding specs of this layout.
  // NOTE(review): presumably one entry per tensor dimension, in dimension
  // order -- confirm against the Layout implementation.
  repeated ShardingSpec sharding_specs = 1;

  // Mesh configuration carried by this layout.
  MeshProto mesh_config = 2;

  // Layout type; see LayoutType for the meaning of each value.
  LayoutType type = 3;
}

// Proto representation of a Mesh.
// TODO(allenl): Add a unique mesh ID so we have an efficient way to identify
// meshes in mappings.
//
// Note: fields are not declared in field-number order (`name` is field 3 but
// appears after field 6). Field numbers, not declaration order, define the
// wire format, so do not renumber them.
message MeshProto {
  // The dimensions (name and length) that make up this mesh.
  repeated MeshDimensionProto mesh_dimensions = 1;

  // A list of global device IDs.
  repeated int64 global_device_ids = 2;

  // Device IDs handled by the local process.
  repeated int64 local_device_ids = 4;

  // Devices handled by the local process.
  repeated string local_devices = 5;

  // Global device names (Optional). Set when multiple device meshes are in use.
  repeated string global_devices = 6;

  // Required name which uniquely identifies this mesh in a program.
  string name = 3;

  // If true, ops on this mesh will use XLA SPMD.
  bool use_xla_spmd = 7;

  // The device string when the mesh is a single device mesh. If it is not
  // empty, then global_device_ids, local_device_ids, local_devices and
  // global_devices are all empty.
  string single_device = 8;
}
